#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import argparse
from itertools import zip_longest, chain
import os
import shutil
import multiprocessing

from fairseq.data import indexed_dataset, dictionary
from fairseq.tokenizer import Tokenizer, tokenize_line


def get_parser():
    """Return the command-line parser for the preprocessing script.

    The parser covers the language pair, the comma-separated train/valid/test
    file prefixes, dictionary options (frequency thresholds, vocabulary size,
    reuse of existing dictionaries, joined dictionary, padding factor), the
    destination directory and output format, and an optional alignment file.
    """
    parser = argparse.ArgumentParser(
        description='Data pre-processing: Create dictionary and store data in binary format')
    parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', help='source language')
    parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', help='target language')
    parser.add_argument('--trainpref', metavar='FP', default=None, help='comma separated, train file prefixes')
    parser.add_argument('--validpref', metavar='FP', default=None, help='comma separated, valid file prefixes')
    parser.add_argument('--testpref', metavar='FP', default=None, help='comma separated, test file prefixes')
    parser.add_argument('--destdir', metavar='DIR', default='data-bin', help='destination dir')
    parser.add_argument('--thresholdtgt', metavar='N', default=0, type=int,
                        help='map words appearing less than threshold times to unknown')
    parser.add_argument('--thresholdsrc', metavar='N', default=0, type=int,
                        help='map words appearing less than threshold times to unknown')
    parser.add_argument('--tgtdict', metavar='FP', help='reuse given target dictionary')
    parser.add_argument('--srcdict', metavar='FP', help='reuse given source dictionary')
    parser.add_argument('--nwordstgt', metavar='N', default=-1, type=int, help='number of target words to retain')
    parser.add_argument('--nwordssrc', metavar='N', default=-1, type=int, help='number of source words to retain')
    parser.add_argument('--alignfile', metavar='ALIGN', default=None, help='an alignment file (optional)')
    parser.add_argument('--output-format', metavar='FORMAT', default='binary', choices=['binary', 'raw'],
                        help='output format (optional)')
    parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
    parser.add_argument('--only-source', action='store_true', help='Only process the source language')
    parser.add_argument('--padding-factor', metavar='N', default=8, type=int,
                        help='Pad dictionary size to be multiple of N')
    # NOTE(review): --reverse-order is parsed but not read by the preprocessing code in this file.
    parser.add_argument('--reverse-order', action='store_true', help='Reverse source and target sequence')
    return parser


def build_dictionary(filenames):
    """Create a new Dictionary and add the tokens of every file in ``filenames``.

    Each file is fed through ``Tokenizer.add_file_to_dictionary`` using
    ``tokenize_line``. The dictionary is returned un-finalized.
    """
    vocab = dictionary.Dictionary()
    for path in filenames:
        Tokenizer.add_file_to_dictionary(path, vocab, tokenize_line)
    return vocab


def load_dictionary(filename):
    """Load and return a previously saved Dictionary from ``filename``."""
    return dictionary.Dictionary.load(filename)


def train_paths(lang):
    """Return one training file path per comma-separated ``args.trainpref`` entry.

    Each path is ``<prefix>.<lang>``, or just ``<prefix>`` when ``lang`` is
    falsy (None or empty string). Reads the module-level ``args``.
    """
    suffix = '.' + lang if lang else ''
    return [prefix + suffix for prefix in args.trainpref.split(',')]


def file_name(prefix, lang):
    """Return ``prefix`` with ``.<lang>`` appended, or ``prefix`` unchanged if ``lang`` is None."""
    if lang is None:
        return prefix
    return prefix + f'.{lang}'


def dest_path(prefix, lang):
    """Return the path of ``prefix[.lang]`` inside ``args.destdir`` (module-level ``args``)."""
    name = file_name(prefix, lang)
    return os.path.join(args.destdir, name)


def dict_path(lang):
    """Return the output path of the dictionary file for ``lang`` (``dict.<lang>.txt``)."""
    return '{}.txt'.format(dest_path('dict', lang))


def dataset_dest_path(output_prefix, lang, extension):
    """Return the path of a binarized dataset file inside ``args.destdir``.

    The file name is ``{output_prefix}.{source_lang}-{target_lang}.{lang}.{extension}``.
    When ``lang`` is None it is ``{output_prefix}.{extension}``. The path is
    joined with ``os.path.join``, like ``dest_path``, rather than a hard-coded
    '/' separator. Reads the module-level ``args``.
    """
    lang_part = f'.{args.source_lang}-{args.target_lang}.{lang}' if lang is not None else ''
    return os.path.join(args.destdir, f'{output_prefix}{lang_part}.{extension}')


def make_binary_dataset(input_prefix, output_prefix, lang, src_ids=None):
    """Binarize ``{input_prefix}[.{lang}]`` using the saved dictionary for ``lang``.

    The dictionary is loaded from ``dict_path(lang)``. Every tensor produced by
    ``Tokenizer.binarize`` is appended to an ``IndexedDatasetBuilder``, which
    writes the ``.bin`` file and is finalized into the matching ``.idx`` file
    (paths from ``dataset_dest_path``). Summary statistics are printed.

    Args:
        input_prefix: path prefix of the text file to binarize.
        output_prefix: split name used for the output files (e.g. 'train').
        lang: language suffix, or None for an un-suffixed file.
        src_ids: forwarded unchanged to ``Tokenizer.binarize`` (presumably the
            ids returned for the source side, used for the copy statistics --
            confirm against Tokenizer).

    Returns:
        The ``ids`` value returned by ``Tokenizer.binarize``.
    """
    # Named ``vocab`` rather than ``dict`` to avoid shadowing the builtin.
    vocab = dictionary.Dictionary.load(dict_path(lang))
    print('| [{}] Dictionary: {} types'.format(lang, len(vocab)))
    ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_path(output_prefix, lang, 'bin'))

    def consumer(tensor):
        ds.add_item(tensor)

    input_file = file_name(input_prefix, lang)
    res, ids = Tokenizer.binarize(input_file, vocab, consumer, src_ids=src_ids)
    # An empty input file yields ntok == 0; avoid ZeroDivisionError in the report.
    ntok = max(res['ntok'], 1)
    print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}, {:.3}% replaced by copy'.format(
        lang, input_file, res['nseq'], res['ntok'],
        100 * res['nunk'] / ntok, vocab.unk_word, 100 * res['ncopied'] / ntok))
    ds.finalize(dataset_dest_path(output_prefix, lang, 'idx'))

    return ids


def make_dataset(input_prefix, output_prefix, lang, trg_lang):
    """Produce one data split for the ``lang`` -> ``trg_lang`` pair.

    With ``--output-format binary`` the source side is binarized first. The ids
    it returns are passed to the target-side binarization. In every output
    format the original source and target text files are then copied into the
    destination folder as ``{output_prefix}.{source_lang}-{target_lang}.{lang}``.
    Reads the module-level ``args``.
    """
    if args.output_format == 'binary':
        source_ids = make_binary_dataset(input_prefix, output_prefix, lang)
        make_binary_dataset(input_prefix, output_prefix, trg_lang, source_ids)

    pair_prefix = output_prefix + '.{}-{}'.format(args.source_lang, args.target_lang)
    # Copy the original text files to the destination folder, source side first.
    for side in (lang, trg_lang):
        shutil.copyfile(file_name(input_prefix, side), dest_path(pair_prefix, side))


def _init_pool_worker(parsed_args):
    """Bind ``parsed_args`` as this module's global ``args`` inside a pool worker.

    The dataset helpers read the module-level ``args``, which is only assigned
    in the ``if __name__ == '__main__'`` block. Under the 'spawn' start method a
    worker re-imports the module without running that block, so the global has
    to be set explicitly. Under 'fork' this simply rebinds the same value.
    """
    global args
    args = parsed_args


def make_all(lang, trg_lang):
    """Build the test, train and valid splits for the ``lang`` -> ``trg_lang`` pair.

    Each comma-separated prefix of ``--testpref``/``--trainpref``/``--validpref``
    becomes a split named ``test``/``train``/``valid``, with an index suffix for
    every prefix after the first (``test1``, ``train2``, ...). Train splits are
    processed in a process pool of at most 8 workers. Test and valid splits are
    processed sequentially. Reads the module-level ``args``.
    """
    if args.testpref:
        for k, testpref in enumerate(args.testpref.split(',')):
            outprefix = 'test{}'.format(k) if k > 0 else 'test'
            make_dataset(testpref, outprefix, lang, trg_lang)
    if args.trainpref:
        f_args = [(trainpref, 'train{}'.format(k) if k > 0 else 'train', lang, trg_lang)
                  for k, trainpref in enumerate(args.trainpref.split(','))]
        # Workers need the parsed arguments too: hand them over explicitly rather
        # than relying on the global being inherited via fork.
        with multiprocessing.Pool(processes=min(8, len(f_args)),
                                  initializer=_init_pool_worker,
                                  initargs=(args,)) as pool:
            pool.starmap(make_dataset, f_args)
    if args.validpref:
        for k, validpref in enumerate(args.validpref.split(',')):
            outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
            make_dataset(validpref, outprefix, lang, trg_lang)


def _build_or_load_dictionaries(args, target):
    """Return ``(src_dict, tgt_dict)`` built from training data or loaded from disk.

    With ``--joined-dictionary`` a single dictionary is built from the training
    files of both languages and returned for both sides. Otherwise each side is
    loaded from ``--srcdict``/``--tgtdict`` when given, or built from the
    training files. ``tgt_dict`` is None when ``target`` is false and no joined
    dictionary is used.
    """
    if args.joined_dictionary:
        assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
        assert not args.tgtdict, 'cannot combine --tgtdict and --joined-dictionary'
        assert args.trainpref, "--trainpref must be set"
        # Deduplicate while keeping a stable order: iterating a ``set`` of str
        # varies between runs (hash randomization), which would make the
        # dictionary build order irreproducible.
        joined_paths = dict.fromkeys(chain.from_iterable(
            train_paths(lang) for lang in [args.source_lang, args.target_lang]))
        joined = build_dictionary(list(joined_paths))
        return joined, joined

    if args.srcdict:
        src_dict = load_dictionary(args.srcdict)
    else:
        # Only needed when a dictionary has to be built from training data.
        assert args.trainpref, "--trainpref must be set"
        src_dict = build_dictionary(train_paths(args.source_lang))

    tgt_dict = None
    if target:
        if args.tgtdict:
            tgt_dict = load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set"
            tgt_dict = build_dictionary(train_paths(args.target_lang))
    return src_dict, tgt_dict


def _write_alignment_dict(args):
    """Write ``alignment.{src}-{tgt}.txt`` mapping each source word to its most aligned target word.

    The first training file of each language is read in lockstep with
    ``--alignfile``. Each alignment line is a whitespace-separated list of
    ``i-j`` pairs, where ``i`` indexes the tokenized source line and ``j`` the
    tokenized target line. Pairs where either side maps to unk are skipped.
    """
    assert args.trainpref, "--trainpref must be set if --alignfile is specified"
    src_file_name = train_paths(args.source_lang)[0]
    tgt_file_name = train_paths(args.target_lang)[0]
    src_dict = dictionary.Dictionary.load(dict_path(args.source_lang))
    tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang))

    # freq_map[src_idx][tgt_idx] -> number of times that pair was aligned
    freq_map = {}
    with open(args.alignfile, 'r') as align_file, \
            open(src_file_name, 'r') as src_file, \
            open(tgt_file_name, 'r') as tgt_file:
        for lineno, (a, s, t) in enumerate(zip_longest(align_file, src_file, tgt_file), start=1):
            if a is None or s is None or t is None:
                raise ValueError(
                    'alignment, source and target files differ in length (at line {})'.format(lineno))
            si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False)
            ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False)
            for pair in a.split():
                sai, tai = pair.split('-')
                # Normalise to plain int so dict keys compare by value
                # (NOTE(review): if tokenize returns tensors, their elements
                # hash by identity -- confirm against Tokenizer).
                srcidx = int(si[int(sai)])
                tgtidx = int(ti[int(tai)])
                if srcidx == src_dict.unk() or tgtidx == tgt_dict.unk():
                    continue
                assert srcidx != src_dict.pad()
                assert srcidx != src_dict.eos()
                assert tgtidx != tgt_dict.pad()
                assert tgtidx != tgt_dict.eos()

                counts = freq_map.setdefault(srcidx, {})
                counts[tgtidx] = counts.get(tgtidx, 0) + 1

    align_dict = {srcidx: max(counts, key=counts.get) for srcidx, counts in freq_map.items()}

    out_path = os.path.join(args.destdir, 'alignment.{}-{}.txt'.format(
        args.source_lang, args.target_lang))
    with open(out_path, 'w') as f:
        for k, v in align_dict.items():
            print('{} {}'.format(src_dict[k], tgt_dict[v]), file=f)


def main(args):
    """Run preprocessing: build/save dictionaries, binarize all splits, and optionally write an alignment dict.

    Args:
        args: parsed command-line namespace from ``get_parser()``. NOTE: the
            module-level helpers read the global ``args``, so this is expected
            to be the same object.
    """
    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    src_dict, tgt_dict = _build_or_load_dictionaries(args, target)

    src_dict.finalize(
        threshold=args.thresholdsrc,
        nwords=args.nwordssrc,
        padding_factor=args.padding_factor,
    )
    src_dict.save(dict_path(args.source_lang))
    if target:
        # A joined dictionary is the same object as src_dict and was finalized above.
        if not args.joined_dictionary:
            tgt_dict.finalize(
                threshold=args.thresholdtgt,
                nwords=args.nwordstgt,
                padding_factor=args.padding_factor,
            )
        tgt_dict.save(dict_path(args.target_lang))

    make_all(args.source_lang, args.target_lang)

    print('| Wrote preprocessed data to {}'.format(args.destdir))

    if args.alignfile:
        _write_alignment_dict(args)


if __name__ == '__main__':
    # Bound at module level on purpose: the helper functions in this file read
    # the global ``args``.
    args = get_parser().parse_args()
    main(args)
